{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "source": [
    "import torch\n",
    "from models import Showo, MAGVITv2\n",
    "from training.prompting_utils import UniversalPrompting, create_attention_mask_for_mmu, create_attention_mask_for_mmu_vit\n",
    "from training.utils import get_config, flatten_omega_conf, image_transform\n",
    "from transformers import AutoTokenizer\n",
    "from models.clip_encoder import CLIPVisionTower\n",
    "from transformers import CLIPImageProcessor\n",
    "import training.conversation as conversation_lib\n",
    "\n",
    "conversation_lib.default_conversation = conversation_lib.conv_templates[\"phi1.5\"]\n"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "source": [
    "# config load -  'showo_demo_w_clip_vit.yaml'\n",
    "from omegaconf import DictConfig, ListConfig, OmegaConf\n",
    "config = OmegaConf.load('configs/showo_demo_w_clip_vit.yaml')"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "source": [
    "# device setup\n",
    "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n",
    "# device = \"cpu\""
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "source": [
    "\n",
    "# show o tokenizer setup and adding special tokens to universal prompting\n",
    "# llm model : 'microsoft/phi-1_5'\n",
    "tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side =\"left\")\n",
    "uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,\n",
    "                                       special_tokens=(\"<|soi|>\", \"<|eoi|>\", \"<|sov|>\", \"<|eov|>\", \"<|t2i|>\", \"<|mmu|>\", \"<|t2v|>\", \"<|v2v|>\", \"<|lvg|>\"),\n",
    "                                       ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)\n"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "source": [
    "# setting up the visual question answering model: magvit-v2\n",
    "vq_model = MAGVITv2\n",
    "vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)\n",
    "vq_model.requires_grad_(False)\n",
    "vq_model.eval()"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "source": [
    "# setting up vision tower: clip-vit\n",
    "vision_tower_name =\"openai/clip-vit-large-patch14-336\"\n",
    "vision_tower = CLIPVisionTower(vision_tower_name).to(device)\n",
    "clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)\n"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "source": [
    "# setting up the showo model \n",
    "model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)\n",
    "model.eval()"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "source": [
    "# setting up the parameters\n",
    "temperature = 0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions\n",
    "top_k = 1  # retain only the top_k most likely tokens, clamp others to have 0 probability\n",
    "SYSTEM_PROMPT = \"A chat between a curious user and an artificial intelligence assistant. \" \\\n",
    "                \"The assistant gives helpful, detailed, and polite answers to the user's questions.\"\n",
    "SYSTEM_PROMPT_LEN = 28\n"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference "
   ]
  },
  {
     "cell_type": "code",
     "execution_count": 1,
     "metadata": {},
     "source": [
       "import os\n",
       "import requests\n",
       "from IPython.display import Image\n",
       "from urllib.parse import urlparse\n",
       "\n",
       "def load_image(path_or_url, save_dir=\"downloaded_images\"):\n",
       "    \"\"\"Load image from local path or URL.\"\"\"\n",
       "    if os.path.exists(path_or_url):\n",
       "        return Image(filename=path_or_url)\n",
       "\n",
       "    os.makedirs(save_dir, exist_ok=True)\n",
       "    filename = os.path.join(save_dir, os.path.basename(urlparse(path_or_url).path))\n",
       "    \n",
       "    with requests.get(path_or_url, stream=True) as r:\n",
       "        if r.status_code == 200:\n",
       "            with open(filename, \"wb\") as f:\n",
       "                for chunk in r.iter_content(1024):\n",
       "                    f.write(chunk)\n",
       "            return Image(filename=filename)\n",
       "    \n",
       "    print(\"Failed to load image.\")\n",
       "    return None\n",
       "\n",
       "# Example usage\n",
       "image_path_or_url = \"/home/grads/h/hasnat.md.abdullah/Show-o/mmu_validation/sofa_under_water.jpg\"  # Or a URL\n",
       "load_image(image_path_or_url)"
     ],
     "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "source": [
    "# inference\n",
    "from PIL import Image\n",
    "## arguments\n",
    "input_image_path =\"./mmu_validation/sofa_under_water.jpg\"\n",
    "questions ='Please describe this image in detail. *** Do you think the image is unusual or not?'\n",
    "\n",
    "## processing\n",
    "questions = questions.split('***')\n",
    "image_ori = Image.open(input_image_path).convert(\"RGB\")\n",
    "# tranforming the image to the required resolution:256x256\n",
    "image = image_transform(image_ori, resolution = config.dataset.params.resolution).to(device)\n",
    "image = image.unsqueeze(0)\n",
    "print(f\"image shape: {image.shape}\") # torch.Size([1, 3, 256, 256])\n",
    "pixel_values = clip_image_processor.preprocess(image_ori,return_tensors=\"pt\")['pixel_values'][0]\n",
    "print(f\"pixel values shape: {pixel_values.shape}\")\n",
    "image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer)\n",
    "print(f\"image tokens shape: {image_tokens.shape}\") # torch.Size([1, 256])\n",
    "batch_size = 1\n",
    "\n",
    "## inference\n",
    "for question in questions: \n",
    "  conv = conversation_lib.default_conversation.copy()\n",
    "  print(f\"conversation: {conv}\")\n",
    "  conv.append_message(conv.roles[0], question)\n",
    "  conv.append_message(conv.roles[1], None)\n",
    "  prompt_question = conv.get_prompt()\n",
    "  # print(prompt_question)\n",
    "  question_input = []\n",
    "  question_input.append(prompt_question.strip())\n",
    "  print(f\"system prompt: {SYSTEM_PROMPT}\")\n",
    "  input_ids_system = [uni_prompting.text_tokenizer(SYSTEM_PROMPT, return_tensors=\"pt\", padding=\"longest\").input_ids for _ in range(batch_size)]\n",
    "  print(f\"system prompt input ids: {input_ids_system}\")\n",
    "  input_ids_system = torch.stack(input_ids_system, dim=0)\n",
    "  assert input_ids_system.shape[-1] == 28\n",
    "  print(f\"after torch stacking: {input_ids_system}\")\n",
    "  input_ids_system = input_ids_system.clone().detach().to(device)\n",
    "  # inputs_ids_system = input_ids_system.to(device)\n",
    "#   inputs_ids_system = torch.tensor(input_ids_system).to(device).squeeze(0)\n",
    "  \n",
    "  print(f\"after moving to device: {input_ids_system}\")\n",
    "  input_ids_system = input_ids_system[0]\n",
    "  print(f\"after indexing 0: {input_ids_system}\")\n",
    "  \n",
    "  \n",
    "  print(f\"question input: {question_input}\")\n",
    "  input_ids = [uni_prompting.text_tokenizer(prompt, return_tensors=\"pt\", padding=\"longest\").input_ids for prompt in question_input]\n",
    "  print(f\"after tokenizing the question: {input_ids}\")\n",
    "  input_ids = torch.stack(input_ids)\n",
    "  print(f\"after torch stacking: {input_ids}\")\n",
    "  input_ids = torch.nn.utils.rnn.pad_sequence(\n",
    "                        input_ids, batch_first=True, padding_value=uni_prompting.text_tokenizer.pad_token_id\n",
    "                )\n",
    "  print(f\"after padding: {input_ids}\")\n",
    "  # input_ids = torch.tensor(input_ids).to(device).squeeze(0)\n",
    "  input_ids = input_ids.clone().detach().to(device).squeeze(0)\n",
    "  print(f\"after moving to device: {input_ids}\")\n",
    "  input_ids_llava = torch.cat([\n",
    "                          (torch.ones(input_ids.shape[0], 1) *uni_prompting.sptids_dict['<|mmu|>']).to(device),\n",
    "                          input_ids_system,\n",
    "                          (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),\n",
    "                          # place your img embedding here\n",
    "                          (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),\n",
    "                          input_ids,\n",
    "                  ], dim=1).long()\n",
    "  print(input_ids_llava)\n",
    "  \n",
    "  images_embeddings = vision_tower(pixel_values[None])\n",
    "  print(f\"images embeddings shape: {images_embeddings.shape}\")# torch.Size([1, 576, 1024])\n",
    "  images_embeddings = model.mm_projector(images_embeddings)\n",
    "  print(f\"images embeddings shape after projection: {images_embeddings.shape}\") \n",
    "\n",
    "  text_embeddings = model.showo.model.embed_tokens(input_ids_llava)\n",
    "\n",
    "  #full input seq\n",
    "  part1 = text_embeddings[:, :2+SYSTEM_PROMPT_LEN,:]\n",
    "  part2 = text_embeddings[:, 2+SYSTEM_PROMPT_LEN:,:]\n",
    "  input_embeddings = torch.cat((part1,images_embeddings,part2),dim=1)\n",
    "\n",
    "  attention_mask_llava = create_attention_mask_for_mmu_vit(input_embeddings,system_prompt_len=SYSTEM_PROMPT_LEN)\n",
    "\n",
    "  cont_toks_list = model.mmu_generate(\n",
    "    input_embeddings = input_embeddings,\n",
    "    attention_mask = attention_mask_llava[0].unsqueeze(0),\n",
    "    max_new_tokens = 100,\n",
    "    top_k = top_k,\n",
    "    eot_token = uni_prompting.sptids_dict['<|eov|>']\n",
    "  )\n",
    "  \n",
    "  cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]\n",
    "  text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list,skip_special_tokens=True)\n",
    "  print(f\"User: {question}, \\nAnswer: {text[0]}\")\n",
    "\n",
    "\n"
   ],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
